package com.github.defense.starter.gateway.filter;

import com.github.defense.common.request.DefenseRequest;
import com.github.defense.common.response.DefenseResponse;
import com.github.defense.common.response.Result;
import com.github.defense.plugin.base.DefensePlugin;
import com.github.defense.starter.common.plugin.DefaultDefensePluginChain;
import com.github.defense.starter.gateway.utils.WebUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;

import java.nio.charset.StandardCharsets;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

/**
 * Gateway plugin filter: a Spring Cloud Gateway {@link GlobalFilter} that runs every
 * registered {@link DefensePlugin} against the incoming request.
 * <p>
 * If the plugin chain returns a non-null {@link DefenseResponse}, it is wrapped with
 * {@code Result.ok(...)} and written straight back to the client with a JSON content type,
 * and the exchange is NOT passed on to the rest of the gateway filter chain. Otherwise the
 * exchange continues down the chain unchanged.
 * <p>
 * TODO: needs to be made compatible with the route selector {@code LoadBalancerClientFilter}
 * (this filter currently runs at {@link Ordered#HIGHEST_PRECEDENCE}).
 *
 * @Author: lettger
 * @Date: 2021/12/11 5:04 PM
 */
@Slf4j
public class GatewayPluginFilter implements GlobalFilter, Ordered {


    /**
     * Plugins sorted by ascending {@link DefensePlugin#getOrder()}; built once in the constructor.
     */
    private final List<DefensePlugin> plugins;

    /**
     * Creates the filter from the given plugins.
     *
     * @param plugins the plugins to run; must not be null. They are copied into a new list
     *                sorted by {@link DefensePlugin#getOrder()}.
     * @throws NullPointerException if {@code plugins} is null
     */
    public GatewayPluginFilter(final List<DefensePlugin> plugins) {
        this.plugins = loadPlugins(plugins);
    }

    /**
     * Sorts the plugins by ascending order value.
     * <p>
     * The list stream is ordered, so the sort is stable: plugins with equal order values keep
     * their original relative position.
     *
     * @param plugins the plugins to sort; must not be null
     * @return a new list containing the plugins in ascending order
     * @throws NullPointerException if {@code plugins} is null
     */
    private static List<DefensePlugin> loadPlugins(List<DefensePlugin> plugins) {
        // Fail fast with a clear message instead of an anonymous NPE inside the stream pipeline.
        return Objects.requireNonNull(plugins, "plugins must not be null")
                .stream()
                .sorted(Comparator.comparingInt(DefensePlugin::getOrder))
                .collect(Collectors.toList());
    }

    /**
     * Runs the plugin chain for this exchange.
     * <p>
     * If the chain produces a response, that response is written to the client and the gateway
     * chain is not continued. Otherwise the exchange is delegated to {@code chain}.
     *
     * @param exchange the current server exchange
     * @param chain    the remaining gateway filter chain
     * @return a {@link Mono} that completes when the response is written or the chain finishes
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        // Defer so the plugin chain runs on subscription rather than at assembly time. Any
        // exception it throws is then delivered as an error signal instead of escaping filter().
        return Mono.defer(() -> {
            // NOTE(review): the plugin chain is executed synchronously on the calling (event-loop)
            // thread; if any plugin performs blocking work it should be offloaded - confirm.
            DefenseResponse response = new DefaultDefensePluginChain(plugins)
                    .execute(toDefenseRequest(exchange.getRequest()));
            if (Objects.nonNull(response)) {
                return writeResponse(exchange, response);
            }
            return chain.filter(exchange);
        });
    }

    /**
     * Copies the URI, query parameters, headers, method and client IP of the gateway request into
     * a new {@link DefenseRequest}.
     *
     * @param request the incoming gateway request
     * @return a populated {@link DefenseRequest}
     */
    private static DefenseRequest toDefenseRequest(ServerHttpRequest request) {
        DefenseRequest defenseRequest = new DefenseRequest();
        defenseRequest.setUri(request.getURI());
        defenseRequest.setParams(request.getQueryParams());
        defenseRequest.setHeaders(request.getHeaders());
        defenseRequest.setMethod(request.getMethodValue());
        defenseRequest.setIp(WebUtils.getIpAddress(request));
        return defenseRequest;
    }

    /**
     * Short-circuits the exchange by writing {@code Result.ok(response)} as the response body.
     * <p>
     * The body is the UTF-8 bytes of {@code Result#toString()} with an
     * {@code application/json} content type. The HTTP status code is not modified here.
     * NOTE(review): assumes {@code Result#toString()} renders JSON - confirm.
     *
     * @param exchange the current server exchange
     * @param response the response produced by the plugin chain
     * @return a {@link Mono} that completes when the body has been written
     */
    private static Mono<Void> writeResponse(ServerWebExchange exchange, DefenseResponse response) {
        exchange.getResponse().getHeaders().setContentType(MediaType.APPLICATION_JSON);
        byte[] body = Objects.requireNonNull(Result.ok(response).toString())
                .getBytes(StandardCharsets.UTF_8);
        return exchange.getResponse().writeWith(Mono.just(exchange.getResponse()
                .bufferFactory().wrap(body)));
    }


    /**
     * Returns {@link Ordered#HIGHEST_PRECEDENCE} so this filter runs as early as possible among
     * the gateway's global filters.
     *
     * @return the filter order
     */
    @Override
    public int getOrder() {
        return Ordered.HIGHEST_PRECEDENCE;
    }
}
